/* Copyright (C) 1999-2014 Massachusetts Institute of Technology.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <math.h>

#include "config.h"
#include <check.h>
#include <blasglue.h>
#include <matrices.h>
#include <eigensolver.h>

/* Test matrices shared by main() and the operator callbacks defined
   later in this file.  main() fills in A and B, and derives Ainv from A
   using sqmatrix_copy() followed by sqmatrix_invert(). */
static sqmatrix A, Ainv, B;

/* Note: both macros evaluate their arguments more than once, so do not
   pass expressions that have side effects. */
#define MAX(a,b) ((a) > (b) ? (a) : (b))
#define MIN(a,b) ((a) < (b) ? (a) : (b))

/* Prototypes for the operator callbacks and printing helpers that are
   defined after main() in this file.  Aop and Bop share one signature,
   and Ainvop and Cop share another; main() passes Bop in the B-operator
   slot of eigensolver() and Cop/Ainvop in the preconditioner slot. */
extern void Aop(evectmatrix Xin, evectmatrix Xout, void *data,
		int is_current_eigenvector, evectmatrix Work);
extern void Bop(evectmatrix Xin, evectmatrix Xout, void *data,
		int is_current_eigenvector, evectmatrix Work);
extern void Ainvop(evectmatrix Xin, evectmatrix Xout, void *data,
		   evectmatrix Y, real *eigenvals, sqmatrix YtY);
extern void Cop(evectmatrix Xin, evectmatrix Xout, void *data,
		evectmatrix Y, real *eigenvals, sqmatrix YtY);
extern void printmat(scalar *A, int m, int n, int ldn);
extern void printmat_matlab(scalar *A, int m, int n);

/* Called as the last step of main(), after everything has been freed.
   No definition is visible in this file (presumably supplied by the
   library's debugging allocator -- TODO confirm). */
extern void debug_check_memory_leaks(void);

/* Relative difference between two length-n scalar arrays, measured
   against b:

       sqrt( sum_i |b[i] - a[i]|^2  /  sum_i |b[i]|^2 )

   There is no guard on the denominator: if n <= 0 or every b[i] is
   zero, the quotient is 0/0 or x/0 and the result is not finite. */
real norm_diff(scalar *a, scalar *b, int n)
{
     real ref_normsqr = 0.0;   /* accumulates |b|^2     */
     real err_normsqr = 0.0;   /* accumulates |b - a|^2 */
     int k = 0;

     while (k < n) {
          scalar delta;

          ASSIGN_SCALAR(delta,
                        SCALAR_RE(b[k]) - SCALAR_RE(a[k]),
                        SCALAR_IM(b[k]) - SCALAR_IM(a[k]));
          ref_normsqr += SCALAR_NORMSQR(b[k]);
          err_normsqr += SCALAR_NORMSQR(delta);
          ++k;
     }

     return sqrt(err_normsqr / ref_normsqr);
}

/* Number of evectmatrix work arrays (W[]) that main() allocates and
   passes to eigensolver(). */
#define NWORK 4

/* Fill A with a random Hermitian matrix built from random data.

   A - output matrix; its dimension A.p determines the problem size.
   X - scratch matrix, overwritten with the random data; it is indexed
       over the same A.p * A.p entries as A, so it must be at least
       that large.

   Both arguments are passed by value, but the writes go through their
   .data pointers and are therefore visible to the caller. */
void rand_posdef(sqmatrix A, sqmatrix X)
{
     const int dim = A.p;
     const int nelem = dim * dim;
     int k;

     /* Random entries with real and imaginary parts in [0, 1].  The
        rand() calls are kept inside the macro arguments so that the
        number of draws per element is whatever ASSIGN_SCALAR makes it. */
     for (k = 0; k < nelem; ++k) {
          ASSIGN_SCALAR(X.data[k],
                        rand() * 1.0 / RAND_MAX,
                        rand() * 1.0 / RAND_MAX);
     }

     /* A = adjoint(X) * X, which is Hermitian. */
     sqmatrix_AeBC(A, X, 1, X, 0);

     /* Scale each diagonal entry by the dimension so that a diagonal
        preconditioner has a chance of being useful. */
     for (k = 0; k < dim; ++k) {
          const int d = k * dim + k;

          ASSIGN_SCALAR(A.data[d],
                        dim * SCALAR_RE(A.data[d]),
                        dim * SCALAR_IM(A.data[d]));
     }
}

/* Print the first p iteratively computed eigenvalues and their sum, and
   CHECK each one against the corresponding dense-solver reference value
   (the difference must be below 1e-5 times the reference value). */
static void report_and_check_eigvals(const real *eigvals,
                                     const real *eigvals_dense, int p)
{
     real sum = 0.0;
     int i;

     printf("\nEigenvalues = ");
     for (i = 0; i < p; ++i) {
          sum += eigvals[i];
          printf("  %f", eigvals[i]);
          CHECK(fabs(eigvals[i]-eigvals_dense[i]) < 1e-5 * eigvals_dense[i],
                "incorrect eigenvalue");
     }
     printf("\nEigenvalue sum = %f\n", sum);
}

/* Test driver.  Usage: eigs_test <n> [<seed>]

   Builds a random n x n matrix A (and B, which is then overwritten with
   the identity), solves for the eigenvalues with the dense solver, and
   then checks that eigensolver() reproduces the first p of them under
   several configurations: diagonal preconditioner (Cop), exact inverse
   preconditioner (Ainvop), no preconditioner, and each of the Cop / no
   preconditioner cases again with one fewer work matrix.

   The whole sequence is run twice: trial 0 passes Bop as the B operator
   (generalized problem), trial 1 passes NULL (ordinary problem).

   Returns EXIT_SUCCESS; a failed CHECK does not return here. */
int main(int argc, char **argv)
{
     int i, j, n = 0, p, trial;
     sqmatrix X, U, YtY, Bcopy;
     evectmatrix Y, Y2, Ystart, W[NWORK];
     real *eigvals, *eigvals_dense, sum = 0.0;
     int num_iters, nWork = NWORK;
     evectoperator bop = Bop;

     if (argc >= 2)
          n = atoi(argv[1]);

     /* Seed from the command line if given, otherwise from the clock. */
     srand(argc >= 3 ? atoi(argv[2]) : time(NULL));

     CHECK(n > 0, "illegal argument\nSyntax: eigs_test <n> [<seed>]");

     X = create_sqmatrix(n);
     A = create_sqmatrix(n);
     B = create_sqmatrix(n);
     Bcopy = create_sqmatrix(n);
     Ainv = create_sqmatrix(n);
     U = create_sqmatrix(n);

     /* X serves only as scratch space for both calls. */
     rand_posdef(A, X);
     rand_posdef(B, X);

#if 1
     /* Overwrite B with the identity matrix. */
     for (i = 0; i < n; ++i) {
         for (j = 0; j < n; ++j) {
             ASSIGN_SCALAR(B.data[i * n + j], 0, 0);
         }
         ASSIGN_SCALAR(B.data[i * n + i], 1, 0);
     }
#endif

     sqmatrix_copy(Ainv, A);
     sqmatrix_invert(Ainv, 1, X);

     if (n <= 10) {
          printf("Solving for eigenvalues of %d x %d matrix: \nA = ", n, n);
          printmat_matlab(A.data, n, n);
          printf("B = ");
          printmat_matlab(B.data, n, n);
     }

     CHK_MALLOC(eigvals_dense, real, n);

     /* Number of eigenpairs to solve for: n/4, clamped to [2, 5], and
        never more than n. */
     p = MIN(MIN(5, MAX(n/4, 2)), n);
     printf("\nSolving for %d eigenvals out of %d.\n", p, n);
     YtY = create_sqmatrix(p);
     Y = create_evectmatrix(n, 1, p, n, 0, n);
     Y2 = create_evectmatrix(n, 1, p, n, 0, n);
     Ystart = create_evectmatrix(n, 1, p, n, 0, n);
     /* Always create all NWORK work matrices; nWork below only controls
        how many of them are handed to the eigensolver. */
     for (i = 0; i < NWORK; ++i)
         W[i] = create_evectmatrix(n, 1, p, n, 0, n);
     CHK_MALLOC(eigvals, real, p);

     for (trial = 0; trial < 2; ++trial) {
         /* The dense solvers work on a copy of A held in U. */
         sqmatrix_copy(U, A);
         if (trial > 0) {
             printf("=== Ordinary Eigenproblem ===\n");
             sqmatrix_eigensolve(U, eigvals_dense, X);
             bop = NULL;
             nWork = NWORK - 1;
         }
         else {
             printf("=== Generalized Eigenproblem ===\n");
             sqmatrix_copy(Bcopy, B);
             sqmatrix_gen_eigensolve(U, Bcopy, eigvals_dense, X);
         }

         /* The eigenvectors are actually the columns of U'.  Assign U = U': */
         for (i = 0; i < n; ++i)
             for (j = i + 1; j < n; ++j) {
                 scalar dummy;
                 dummy = U.data[i*n + j];
                 U.data[i*n + j] = U.data[j*n + i];
                 U.data[j*n + i] = dummy;
             }
         for (i = 0; i < n * n; ++i)
             ASSIGN_CONJ(U.data[i], U.data[i]);

         printf("\nSolved A by dense eigensolver.\nEigenvalues = ");
         for (sum = 0.0, i = 0; i < p; ++i) {
             sum += eigvals_dense[i];
             printf("  %f", eigvals_dense[i]);
         }
         printf("\nEigenvalue sum = %f\n", sum);
         printf("\nEigenvectors are (by column): \n");
         printmat(U.data, n, p, n);

         /* Random real starting vectors, reused for every solve below. */
         for (i = 0; i < n*p; ++i)
             ASSIGN_REAL(Ystart.data[i], rand() * 1.0 / RAND_MAX);

         /* Check inverse Ainvop: */
         Aop(Ystart, Y, NULL, 0, Y2);
         /* Ainvop ignores its last three arguments, so Ystart and U are
            just placeholders here. */
         Ainvop(Y, Y2, NULL, Ystart, NULL, U);
         printf("\nDifference |Y - (1/A)*A*Y| / |Y| = %g\n",
                norm_diff(Ystart.data, Y2.data, Y.n * Y.p));

         /* Diagonal preconditioner (Cop). */
         evectmatrix_copy(Y, Ystart);
         eigensolver(Y, eigvals, Aop,NULL, bop,NULL, Cop,NULL, NULL,NULL,
                     W, nWork, 1e-10,&num_iters, EIGS_DEFAULT_FLAGS);
         printf("\nSolved for eigenvectors after %d iterations.\n", num_iters);
         report_and_check_eigvals(eigvals, eigvals_dense, p);

         /* Change phase of eigenvectors to match those solved for prev.: */
         for (i = 0; i < p; ++i) {
             scalar phase;

             /* Ratio of the first-row entries of column i in U and Y. */
             ASSIGN_DIV(phase, U.data[i], Y.data[i]);

             for (j = 0; j < n; ++j) {
                 ASSIGN_MULT(Y.data[j*p + i], Y.data[j*p + i], phase);
             }
         }

         printf("Eigenvectors are (by column): \n");
         printmat(Y.data, n, p, p);
         /* U is not read again during this trial, so it is handed over
            as the third argument here. */
         evectmatrix_XtX(YtY, Y, U);
         printf("adjoint(Y) * Y:\n");
         printmat(YtY.data, p, p, p);

         printf("\nSolving with exact inverse preconditioner...\n");
         evectmatrix_copy(Y, Ystart);
         eigensolver(Y, eigvals, Aop,NULL, bop,NULL, Ainvop,NULL, NULL,NULL,
                     W, nWork, 1e-10, &num_iters, EIGS_DEFAULT_FLAGS);
         printf("Solved for eigenvectors after %d iterations.\n", num_iters);
         report_and_check_eigvals(eigvals, eigvals_dense, p);

         /* One fewer work matrix than above; per the message printed,
            this is what turns conjugate-gradient off. */
         printf("\nSolving without conjugate-gradient...\n");
         evectmatrix_copy(Y, Ystart);
         eigensolver(Y, eigvals, Aop,NULL, bop,NULL, Cop,NULL, NULL,NULL,
                     W, nWork - 1, 1e-10, &num_iters, EIGS_DEFAULT_FLAGS);
         printf("Solved for eigenvectors after %d iterations.\n", num_iters);
         report_and_check_eigvals(eigvals, eigvals_dense, p);

         printf("\nSolving without preconditioning...\n");
         evectmatrix_copy(Y, Ystart);
         eigensolver(Y, eigvals, Aop,NULL, bop,NULL, NULL,NULL, NULL,NULL,
                     W, nWork, 1e-10, &num_iters, EIGS_DEFAULT_FLAGS);
         printf("Solved for eigenvectors after %d iterations.\n", num_iters);
         report_and_check_eigvals(eigvals, eigvals_dense, p);

         printf("\nSolving without conjugate-gradient or preconditioning...\n");
         evectmatrix_copy(Y, Ystart);
         eigensolver(Y, eigvals, Aop,NULL, bop,NULL, NULL,NULL, NULL,NULL,
                     W, nWork - 1, 1e-10, &num_iters, EIGS_DEFAULT_FLAGS);
         printf("Solved for eigenvectors after %d iterations.\n", num_iters);
         report_and_check_eigvals(eigvals, eigvals_dense, p);
     }

     destroy_sqmatrix(A);
     destroy_sqmatrix(B);
     destroy_sqmatrix(Bcopy);
     destroy_sqmatrix(Ainv);
     destroy_sqmatrix(X);
     destroy_sqmatrix(U);
     destroy_sqmatrix(YtY);
     destroy_evectmatrix(Y);
     destroy_evectmatrix(Y2);
     destroy_evectmatrix(Ystart);
     /* Free every work matrix that was created.  nWork has been lowered
        to NWORK - 1 by the second trial, so it must not be used as the
        bound here or W[NWORK-1] would be leaked. */
     for (i = 0; i < NWORK; ++i)
          destroy_evectmatrix(W[i]);

     free(eigvals);
     free(eigvals_dense);

     debug_check_memory_leaks();

     return EXIT_SUCCESS;
}

/* Operator callback that applies the global matrix A to Xin and stores
   the result in Xout (gemm with no transposes, alpha = 1, beta = 0,
   i.e. presumably Xout = A * Xin).

   data, is_current_eigenvector and Work are part of the callback
   signature but are not used.  CHECK verifies that Xin and Xout both
   have A.p rows and the same number of columns. */
void Aop(evectmatrix Xin, evectmatrix Xout, void *data,
	 int is_current_eigenvector, evectmatrix Work)
{
     const int rows = Xout.n;
     const int cols = Xout.p;
     const int inner = Xin.n;

     CHECK(Xin.n == A.p && Xout.n == A.p && Xout.p == Xin.p,
           "matrices not conformant");

     blasglue_gemm('N', 'N', rows, cols, inner,
                   1.0, A.data, A.p,
                   Xin.data, Xin.p,
                   0.0, Xout.data, Xout.p);
}

/* Operator callback that applies the global matrix B to Xin and stores
   the result in Xout (gemm with no transposes, alpha = 1, beta = 0,
   i.e. presumably Xout = B * Xin).

   data, is_current_eigenvector and Work are part of the callback
   signature but are not used.  CHECK verifies that Xin and Xout both
   have B.p rows and the same number of columns. */
void Bop(evectmatrix Xin, evectmatrix Xout, void *data,
	 int is_current_eigenvector, evectmatrix Work)
{
     const int rows = Xout.n;
     const int cols = Xout.p;
     const int inner = Xin.n;

     CHECK(Xin.n == B.p && Xout.n == B.p && Xout.p == Xin.p,
           "matrices not conformant");

     blasglue_gemm('N', 'N', rows, cols, inner,
                   1.0, B.data, B.p,
                   Xin.data, Xin.p,
                   0.0, Xout.data, Xout.p);
}

/* Preconditioner-style callback that applies the global matrix Ainv to
   Xin and stores the result in Xout (gemm with no transposes,
   alpha = 1, beta = 0, i.e. presumably Xout = Ainv * Xin).

   data, Y, eigenvals and YtY are part of the callback signature but
   are not used.  CHECK verifies that Xin and Xout both have Ainv.p rows
   and the same number of columns. */
void Ainvop(evectmatrix Xin, evectmatrix Xout, void *data,
	    evectmatrix Y, real *eigenvals, sqmatrix YtY)
{
     const int rows = Xout.n;
     const int cols = Xout.p;
     const int inner = Xin.n;

     CHECK(Xin.n == Ainv.p && Xout.n == Ainv.p && Xout.p == Xin.p,
           "matrices not conformant");

     blasglue_gemm('N', 'N', rows, cols, inner,
                   1.0, Ainv.data, Ainv.p,
                   Xin.data, Xin.p,
                   0.0, Xout.data, Xout.p);
}

/* Older diagonal preconditioner; nothing else in this file calls it.

   Xout is first set from Xin and YtY by evectmatrix_XeYS (going by the
   name, presumably Xout = Xin * YtY -- TODO confirm), and then every
   row of Xout is scaled by 1 / |A(row,row)|, or left unscaled when that
   diagonal entry of the global matrix A is exactly zero.

   data, Y and eigenvals are not used. */
void Cop_old(evectmatrix Xin, evectmatrix Xout, void *data,
	 evectmatrix Y, real *eigenvals, sqmatrix YtY)
{
     int row, col;

     CHECK(Xin.n == A.p && Xout.n == A.p && Xout.p == Xin.p,
           "matrices not conformant");

     evectmatrix_XeYS(Xout, Xin, YtY, 1);

     for (row = 0; row < Xout.n; ++row) {
          const int base = row * Xout.p;
          real weight = SCALAR_NORMSQR(A.data[row * A.p + row]);

          /* weight goes from |A(row,row)|^2 to 1/|A(row,row)|, with a
             zero diagonal mapped to a scale factor of 1. */
          if (weight == 0.0)
               weight = 1.0;
          else
               weight = 1.0 / sqrt(weight);

          for (col = 0; col < Xout.p; ++col) {
               scalar old = Xout.data[base + col];

               ASSIGN_SCALAR(Xout.data[base + col],
                             weight * SCALAR_RE(old),
                             weight * SCALAR_IM(old));
          }
     }
}

/* Diagonal preconditioner callback.

   Xout is first set from Xin and YtY by evectmatrix_XeYS (going by the
   name, presumably Xout = Xin * YtY -- TODO confirm), and then every
   entry in a row of Xout is divided by that row's diagonal entry of the
   global matrix A.

   When eigenvals is non-NULL the divisor is formed as
   A(row,row) - 0*eigenvals[col]; the explicit factor of 0 means the
   eigenvalue shift is effectively switched off.  The expression is kept
   in that form on purpose rather than simplified away.

   data and Y are not used. */
void Cop(evectmatrix Xin, evectmatrix Xout, void *data,
	  evectmatrix Y, real *eigenvals, sqmatrix YtY)
{
     int row, col;

     CHECK(Xin.n == A.p && Xout.n == A.p && Xout.p == Xin.p,
           "matrices not conformant");

     evectmatrix_XeYS(Xout, Xin, YtY, 1);

     for (row = 0; row < Xout.n; ++row) {
          const int base = row * Xout.p;
          scalar a_diag = A.data[row * A.p + row];

          for (col = 0; col < Xout.p; ++col) {
               /* Default divisor is the plain diagonal entry. */
               scalar divisor = a_diag;

               if (eigenvals) {
                    ASSIGN_SCALAR(divisor,
                                  SCALAR_RE(a_diag) - 0*eigenvals[col],
                                  SCALAR_IM(a_diag));
               }

               ASSIGN_DIV(Xout.data[base + col],
                          Xout.data[base + col],
                          divisor);
          }
     }
}

/* Print the leading part of an m x n matrix stored row-major in A.

   A   - matrix data; entry (i,j) is A[i*ldn + j]
   m   - number of rows
   n   - number of columns to print from each row
   ldn - distance, in elements, between the starts of consecutive rows

   At most 9 rows and 9 columns are shown.  An ellipsis ("  ..." at the
   end of a row, "   ..." on its own line after the last row shown) is
   printed only when entries really have been left out; a matrix with
   exactly 9 rows or columns is printed in full with no ellipsis. */
void printmat(scalar *A, int m, int n, int ldn)
{
  enum { MAX_SHOWN = 9 };
  int i, j;

  for (i = 0; i < m; ++i) {
       if (i >= MAX_SHOWN) {
         /* There are rows beyond the ones shown. */
         printf("   ...\n");
         break;
       }

       for (j = 0; j < n; ++j) {
         if (j >= MAX_SHOWN) {
           /* There are columns beyond the ones shown. */
           printf("  ...");
           break;
         }
#ifdef SCALAR_COMPLEX
         printf("  (%6.3f,%6.3f)", A[i*ldn + j].re, A[i*ldn + j].im);
#else
         printf("  %6.3f", A[i*ldn + j]);
#endif
       }
       printf("\n");
  }
}

/* Print an m x n matrix, stored row-major and contiguously in A, as a
   bracketed matrix literal: entries are separated by spaces and every
   row is terminated by ";" and a newline.  In a complex build each
   entry is printed as "<re>+<im>i".  The whole matrix is printed; there
   is no truncation. */
void printmat_matlab(scalar *A, int m, int n)
{
  int row, col;

  printf("[");
  for (row = 0; row < m; ++row) {
       const int base = row * n;

       for (col = 0; col < n; ++col) {
#ifdef SCALAR_COMPLEX
         printf("  %g+%gi", A[base + col].re, A[base + col].im);
#else
         printf("  %g", A[base + col]);
#endif
       }
       printf(";\n");
  }
  printf("]\n");
}
